Skip to content

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093

Open
cael-ling wants to merge 4 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache
Open

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093
cael-ling wants to merge 4 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache

Conversation

@cael-ling

Copy link
Copy Markdown
Contributor

Description

For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside general_gemm on every micro-batch and thrown away, so with N micro-batches the weight scale swizzle runs 2*N times per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already set optimize_for_gemm=True and were pre-swizzled; only the weight was missed.)

This PR sets weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the swizzle is done once at quantize time, persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM — 2*N2 swizzles per step.

  • Applied to Linear, LayerNormLinear, LayerNormMLP (fc1 + fc2) and GroupedLinear (per expert).

  • Gated to the cached path (is_first_microbatch is not None) with fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there.

  • No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).

  • Swizzling is a pure layout permutation, so numerics are unchanged.

  • New tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py: asserts the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for Linear / LayerNormLinear / GroupedLinear, and that _with_gemm_swizzled_scales is set and persisted on the cached workspace.

  • pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp" — no regressions.

  • pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path" — no regressions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…obatches

For block-scaled NVFP4 a cached weight participates in two GEMMs per step:
fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale
swizzle was recomputed lazily inside every GEMM and discarded, so with N
microbatches the weight scale swizzle ran 2*N times per step even though the
weight is quantized only once.

Because weight RHT is disabled, the weight scales are not swizzled by the
cast-fusion path; with optimize_for_gemm off they also skip the post-quantize
fallback swizzle, so the only swizzle site left for the weight is the lazy one
inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM.
(Activation input/grad_output quantizers already set optimize_for_gemm=True, so
they were pre-swizzled via cast-fusion/fallback; only the weight was missed.)

Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the
swizzle is done once at quantize time (via the post-quantize fallback),
persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused
by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step
instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and
GroupedLinear (per expert).

Gated to the cached path (is_first_microbatch is not None) with fsdp_group is
None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled
scale layout, so pre-swizzling is unsupported there. No-op for recipes whose
scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure
layout permutation, so numerics are unchanged.

Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached
eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for
Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner June 5, 2026 14:29
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR eliminates redundant GEMM-swizzle work for NVFP4 block-scaled weights by setting optimize_for_gemm=True on the weight quantizer whenever the quantized weight is cached across micro-batches. The swizzle is performed once at quantize time, persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every subsequent micro-batch GEMM — reducing 2*N swizzle kernels per step to 2 (one fprop, one dgrad).

  • Applied consistently to Linear, LayerNormLinear, and LayerNormMLP using a symmetric weight_quantizer.optimize_for_gemm = cache_name is not None assignment that correctly resets the flag to False on uncached calls; FSDP2 is excluded via the pre-existing cache_name gate.
  • Applied to GroupedLinear (per-expert) with cache_weight = is_first_microbatch is not None; unlike the other three modules, the flag is only set to True and never explicitly reset to False, leaving a minor behavioral inconsistency when switching from cached to uncached mode.
  • New test file tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py covers all four module kinds for both numeric equivalence and workspace flag persistence, and guards against accidentally enabling the optimization on the uncached path.

Confidence Score: 5/5

Safe to merge; the change is a pure performance optimization — swizzling is a layout permutation, so numerics are unchanged across all code paths.

The optimization is gated identically to the pre-existing weight-caching logic in all four modules. For Linear, LayerNormLinear, and LayerNormMLP the flag is unconditionally written on every forward call so it can never get stuck in the wrong state. The only asymmetry is in GroupedLinear, where the flag is set to True but never explicitly reset to False when caching is disabled — however, the lazy-swizzle fallback inside general_gemm produces the same output, so correctness is preserved.

grouped_linear.py — the optimize_for_gemm flag is not reset to False in the uncached branch, unlike the other three changed modules.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Adds weight_quantizer.optimize_for_gemm = cache_name is not None before the GEMM call, unconditionally setting the flag to True or False; correctly gated by the existing FSDP2/is_first_microbatch check in cache_name.
transformer_engine/pytorch/module/layernorm_linear.py Mirrors linear.py: unconditionally sets weight_quantizer.optimize_for_gemm = cache_name is not None, correctly gated through the pre-existing FSDP2 guard in cache_name.
transformer_engine/pytorch/module/layernorm_mlp.py Sets fc1_weight_quantizer.optimize_for_gemm and fc2_weight_quantizer.optimize_for_gemm independently via their respective cache_name_fc{1,2} variables; symmetric and correct.
transformer_engine/pytorch/module/grouped_linear.py Sets optimize_for_gemm=True only when cache_weight=True; unlike the other three modules, it never resets the flag to False when switching to uncached mode — a minor behavioral inconsistency with the rest of the change.
tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py New test file covering all four module kinds (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear) for both the numerics and the workspace-flag assertion; also covers the lazy path guard.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant Module as Linear/LNLinear/LNmlp/GroupedLinear
    participant Quantizer as WeightQuantizer
    participant Cache as _fp8_workspaces
    participant GEMM as general_gemm

    Note over Caller,GEMM: First micro-batch (is_first_microbatch=True)
    Caller->>Module: "forward(x, is_first_microbatch=True)"
    Module->>Quantizer: "optimize_for_gemm = True"
    Module->>Cache: get("weight") → None (miss)
    Module->>Quantizer: "quantize_(weight) → workspace (_with_gemm_swizzled_scales=True)"
    Module->>Cache: store("weight", workspace)
    Module->>GEMM: gemm(x, workspace) — scales already swizzled

    Note over Caller,GEMM: Subsequent micro-batch (is_first_microbatch=False)
    Caller->>Module: "forward(x, is_first_microbatch=False)"
    Module->>Quantizer: "optimize_for_gemm = True"
    Module->>Cache: get("weight") → workspace (hit, pre-swizzled)
    Note over Module,GEMM: update_workspace=False, no re-quantize
    Module->>GEMM: gemm(x, workspace) — reuses swizzled cache

    Note over Caller,GEMM: Uncached path (is_first_microbatch=None)
    Caller->>Module: "forward(x, is_first_microbatch=None)"
    Module->>Quantizer: "optimize_for_gemm = False"
    Module->>Quantizer: "quantize_(weight) → fresh workspace (_with_gemm_swizzled_scales=False)"
    Module->>GEMM: gemm(x, workspace) — lazy swizzle inside GEMM
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Caller
    participant Module as Linear/LNLinear/LNmlp/GroupedLinear
    participant Quantizer as WeightQuantizer
    participant Cache as _fp8_workspaces
    participant GEMM as general_gemm

    Note over Caller,GEMM: First micro-batch (is_first_microbatch=True)
    Caller->>Module: "forward(x, is_first_microbatch=True)"
    Module->>Quantizer: "optimize_for_gemm = True"
    Module->>Cache: get("weight") → None (miss)
    Module->>Quantizer: "quantize_(weight) → workspace (_with_gemm_swizzled_scales=True)"
    Module->>Cache: store("weight", workspace)
    Module->>GEMM: gemm(x, workspace) — scales already swizzled

    Note over Caller,GEMM: Subsequent micro-batch (is_first_microbatch=False)
    Caller->>Module: "forward(x, is_first_microbatch=False)"
    Module->>Quantizer: "optimize_for_gemm = True"
    Module->>Cache: get("weight") → workspace (hit, pre-swizzled)
    Note over Module,GEMM: update_workspace=False, no re-quantize
    Module->>GEMM: gemm(x, workspace) — reuses swizzled cache

    Note over Caller,GEMM: Uncached path (is_first_microbatch=None)
    Caller->>Module: "forward(x, is_first_microbatch=None)"
    Module->>Quantizer: "optimize_for_gemm = False"
    Module->>Quantizer: "quantize_(weight) → fresh workspace (_with_gemm_swizzled_scales=False)"
    Module->>GEMM: gemm(x, workspace) — lazy swizzle inside GEMM
Loading

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +67 to +72
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)
out.sum().backward()
return out.detach().float(), x.grad.detach().float()


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing LayerNormMLP test coverage

layernorm_mlp.py is one of four files modified by this PR, yet the test suite parametrizes only over ["Linear", "LayerNormLinear"] for both test_weight_swizzle_cache_numerics and test_lazy_path_not_swizzled. The fc1/fc2 two-quantizer path in LayerNormMLP is structurally different from the single-quantizer modules: it independently gates fc1_weight_quantizer.optimize_for_gemm and fc2_weight_quantizer.optimize_for_gemm using separate cache_name_fc1/cache_name_fc2 variables. If either gating expression were wrong (e.g. swapping fc1/fc2 names), existing tests would not catch it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added LayerNormMLP coverage (fc1+fc2 two-quantizer path) to both parametrized tests.

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from FSDP2 condition being irrelevant, LGTM

Comment on lines +178 to +185
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays off — guards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)
@pytest.mark.parametrize("layer_type", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(layer_type, 1024, 1024, device)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — test_lazy_path_not_swizzled now parametrizes over all four module kinds.

x = x.detach().clone().requires_grad_(True)
module.zero_grad(set_to_none=True) # per-micro-batch grads (no accumulation)
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If absence of m_splits argument is the only reason for creating new test for grouped_linear below, then can we add a check on the module in terms of passing m_splits only if module is GroupedLinear, instead of duplicating the test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — folded GroupedLinear into the parametrized test_weight_swizzle_cache_numerics, passing m_splits only for GroupedLinear (in _step); removed the duplicated grouped-only test.

Comment on lines +1738 to +1749
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

@vthumbe1503 vthumbe1503 Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think the comment is relevant In case of FSDP/FSDP2,
For FSDP, The scales are not sharded, and the whole scales are replicated across ranks today. So it doesnt matter if scales are swizzled or not. cc: @denera. Also NVFP4 pre allgather is function specific to FSDP2 not FSDP.
For FSDP2, we havent been caching weights as it causes memory bloating. And weight caching as a mechanism doesnt fit well with fsdp2. This was done for linear and layer_norm_linear but apparently not for grouped_linear in this PR #2805. But fixing that for grouped_linear might be byond scope of this PR. Even if weight caching is still kept as it is, current behavior is to save the entire weight instead of shard in the workspace and so swizzling being present shouldnt cause any issue.

Suggested change
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert.
# No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same applies in other files.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied to all four module files.

cael-ling and others added 2 commits June 17, 2026 00:36
Drop the FSDP/FSDP2 gating on optimize_for_gemm in Linear, LayerNormLinear,
LayerNormMLP and GroupedLinear. FSDP1 replicates (does not shard) the scale
factors, so the swizzle layout is irrelevant there, and weights are not cached
under FSDP2; the guard only added a misleading comment and dead conditions.
Pre-swizzle the weight scales whenever the quantized weight is cached.

Tests:
- Fold the GroupedLinear case into the parametrized
  test_weight_swizzle_cache_numerics by passing m_splits only for
  GroupedLinear, removing the duplicated grouped-only test.
- Add LayerNormMLP coverage (fc1 + fc2 two-quantizer path), generalizing
  the cached-workspace-count assertion per module type.
- Parametrize test_lazy_path_not_swizzled over all four module kinds.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling

Copy link
Copy Markdown
Contributor Author

Pushed a commit addressing the review: removed the irrelevant FSDP gating across all four modules, merged the GroupedLinear test, and added LayerNormMLP coverage. Please take a look, thanks. @vthumbe1503

@cael-ling cael-ling requested a review from vthumbe1503 June 17, 2026 07:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants